{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/lollcat/fab-torch/blob/master/demo/aldp.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ],
   "id": "2f409620a099e9aa"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Alanine dipeptide: Using the model trained with FAB including a replay buffer"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "40945194ebc75a35"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Install dependencies"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "de0b67176f001146"
  },
  {
   "cell_type": "code",
   "source": [
    "# Install the condacolab helper package, which sets up conda/mamba on Colab\n",
    "!pip install -q condacolab\n",
    "\n",
    "import condacolab\n",
    "\n",
    "# The notebook expects the kernel to restart once this call has finished\n",
    "condacolab.install()"
   ],
   "metadata": {
    "id": "UOIXtQ__BqIC"
   },
   "id": "UOIXtQ__BqIC",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "The kernel should have restarted after executing the cell above. You can just continue by executing the next cell."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7f2e4f6d774ec1b9"
  },
  {
   "cell_type": "code",
   "source": [
    "# The restarted kernel has an empty namespace, so import the helper again\n",
    "import condacolab\n",
    "\n",
    "# Run the Mambaforge setup step in the fresh kernel\n",
    "condacolab.install_mambaforge()\n",
    "\n",
    "# Install OpenMM and OpenMMTools (the latter provides the alanine dipeptide test system imported below)\n",
    "!mamba install openmm openmmtools"
   ],
   "metadata": {
    "id": "JNYpsxJpBtDP"
   },
   "id": "JNYpsxJpBtDP",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# Fetch the fab-torch sources into the current working directory\n",
    "!git clone https://github.com/lollcat/fab-torch.git\n",
    "\n",
    "import os\n",
    "\n",
    "# Install the package from the repository root\n",
    "os.chdir(\"fab-torch\")\n",
    "!pip install --upgrade .\n",
    "\n",
    "# Continue from demo/: later cells resolve paths relative to this directory\n",
    "os.chdir(\"demo\")"
   ],
   "metadata": {
    "id": "FKUvWlrsB3dJ"
   },
   "id": "FKUvWlrsB3dJ",
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Downloading and using the model"
   ],
   "metadata": {
    "collapsed": false,
    "id": "f068efcce1dbaf59"
   },
   "id": "f068efcce1dbaf59"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b0c8ad",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:42.594011697Z",
     "start_time": "2024-01-26T17:01:37.908308867Z"
    },
    "id": "a7b0c8ad"
   },
   "outputs": [],
   "source": [
    "# Import packages\n",
    "# NOTE: import the submodule explicitly. A bare `import urllib` does not\n",
    "# guarantee that `urllib.request` (used for the downloads below) is loaded;\n",
    "# it would only work if some other package happened to import it first.\n",
    "import urllib.request\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import mdtraj\n",
    "\n",
    "from tqdm import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from fab.utils.training import load_config\n",
    "from experiments.make_flow.make_aldp_model import make_aldp_model\n",
    "\n",
    "from openmmtools.testsystems import AlanineDipeptideVacuum"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Use double precision as the default floating-point dtype\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "print(device)"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:44.450207787Z",
     "start_time": "2024-01-26T17:01:44.439786285Z"
    },
    "id": "8b882b2b92a75899"
   },
   "id": "8b882b2b92a75899",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Fetch the model configuration file from Hugging Face into the working directory\n",
    "config_url = 'https://huggingface.co/VincentStimper/fab/resolve/main/aldp/config.yaml'\n",
    "urllib.request.urlretrieve(config_url, filename='config.yaml')"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:44.632702250Z",
     "start_time": "2024-01-26T17:01:44.441090404Z"
    },
    "id": "c1e19aede303d50e"
   },
   "id": "c1e19aede303d50e",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Load the downloaded config\n",
    "config = load_config('config.yaml')\n",
    "\n",
    "# Prefix the transform path with '../' so that it is resolved one level\n",
    "# above the current working directory\n",
    "transform_path = config['data']['transform']\n",
    "config['data']['transform'] = '../' + transform_path"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:44.636434508Z",
     "start_time": "2024-01-26T17:01:44.628242436Z"
    },
    "id": "3fc26b6acfdce425"
   },
   "id": "3fc26b6acfdce425",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Build the model described by the config on the selected device\n",
    "# (the trained weights are loaded in a later cell)\n",
    "model = make_aldp_model(config, device)"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:44.883956731Z",
     "start_time": "2024-01-26T17:01:44.636026351Z"
    },
    "id": "aea4aba01583fd7b"
   },
   "id": "aea4aba01583fd7b",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Fetch the trained model checkpoint from Hugging Face into the working directory\n",
    "model_url = 'https://huggingface.co/VincentStimper/fab/resolve/main/aldp/model.pt'\n",
    "urllib.request.urlretrieve(model_url, filename='model.pt')"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:45.529464199Z",
     "start_time": "2024-01-26T17:01:44.893745450Z"
    },
    "id": "f802aa89ef3a76ff"
   },
   "id": "f802aa89ef3a76ff",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Restore the downloaded weights, mapping them onto the selected device\n",
    "weights_path = 'model.pt'\n",
    "model.load(weights_path, map_location=device)"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:01:45.564693805Z",
     "start_time": "2024-01-26T17:01:45.526622626Z"
    },
    "id": "ee8d147bd79deefb"
   },
   "id": "ee8d147bd79deefb",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Generate model samples\n",
    "batch_size = 100\n",
    "n_batches = 1000\n",
    "\n",
    "# Fix the PyTorch seed so that re-running this cell draws the same samples\n",
    "# (assumes the flow samples through torch's global random generators)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "x_list = []\n",
    "\n",
    "# Sampling only: no gradients are needed\n",
    "with torch.no_grad():\n",
    "    for _ in tqdm(range(n_batches)):\n",
    "        # Generate samples in internal coordinates\n",
    "        z_batch = model.flow.sample((batch_size,))\n",
    "        # Transform samples to Cartesian coordinates; the second return value is not used here\n",
    "        x_batch, _ = model.target_distribution.coordinate_transform(z_batch)\n",
    "        # Move the batch to host memory as a NumPy array\n",
    "        x_list.append(x_batch.cpu().numpy())\n",
    "\n",
    "# Join all batches along the first axis into a single array\n",
    "x = np.concatenate(x_list)"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:03:44.058286702Z",
     "start_time": "2024-01-26T17:01:45.565284637Z"
    },
    "id": "b6b09e8c4ec160ed"
   },
   "id": "b6b09e8c4ec160ed",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Get dihedral angles phi and psi for Ramachandran plot\n",
    "aldp = AlanineDipeptideVacuum()\n",
    "topology = mdtraj.Topology.from_openmm(aldp.topology)\n",
    "\n",
    "# One frame per sample; take the atom count from the topology instead of hard-coding 22\n",
    "traj = mdtraj.Trajectory(x.reshape(-1, topology.n_atoms, 3), topology)\n",
    "\n",
    "# Element [1] of each result holds the angle values; flatten them to 1-D for plotting\n",
    "phi = mdtraj.compute_phi(traj)[1].reshape(-1)\n",
    "psi = mdtraj.compute_psi(traj)[1].reshape(-1)"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:03:44.106214544Z",
     "start_time": "2024-01-26T17:03:44.058025851Z"
    },
    "id": "234855e49ebb54ea"
   },
   "id": "234855e49ebb54ea",
   "execution_count": null
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "# Ramachandran plot\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.plot(phi, psi, '.', alpha=0.1)\n",
    "\n",
    "# Show the full [-pi, pi] range of both angles on a square axis\n",
    "ax.set_xlim(-np.pi, np.pi)\n",
    "ax.set_ylim(-np.pi, np.pi)\n",
    "ax.set_box_aspect(1)\n",
    "\n",
    "# Label the figure so it can be read on its own\n",
    "ax.set_xlabel(r'$\\phi$ [rad]')\n",
    "ax.set_ylabel(r'$\\psi$ [rad]')\n",
    "ax.set_title('Ramachandran plot of the flow samples')\n",
    "plt.show()"
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-26T17:03:44.269311099Z",
     "start_time": "2024-01-26T17:03:44.106050360Z"
    },
    "id": "daa3c59a253da4db"
   },
   "id": "daa3c59a253da4db",
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "colab": {
   "provenance": [],
   "include_colab_link": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
